#!/usr/bin/env python3

import math
from typing import Type

import torch
import torch.nn as nn


class ModernHopfieldAttention(nn.Module):

    def __init__(
        self,
        dim: int,
        num_heads: int,
        attn_alpha: float,
        skip_alpha: float,
        causal: bool,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        num_tokens: float = 1024,
    ) -> None:
        """
        MHA (Modern Hopfield Attention) class.

        Args:
            dim (int): Embedding dimension of tokens.
            num_heads (int): Number of attention heads.
            attn_alpha (float): A scaling factor that controls how much the residual connection blends the attention logits from the previous and current layers. A larger value gives more weight to the previous information. The value must be between 0 and 1.
            skip_alpha (float): A scaling factor that controls how much the residual connection blends the attention input and attention output. This skip connection is valid only if `pre_connect` is True. The value must be between 0 and 1.
            pre_connect (bool): A flag that enables skip connection before the projection.
            first_sa (bool): A flag that controls the residual connection of attention. If this flag is True, the first attention layer will be self-attention.
            causal (bool): A flag that enables autoregressive causal masking.
            qkv_bias (bool, optional): A flag that enables the bias term in `nn.Linear` when creating QKV. Defaults to False.
            qk_norm (bool, optional): A flag that enables the normalization function for Query and Key. Defaults to False.
            attn_drop (float, optional): The probability of applying `nn.Dropout` to the attention. Defaults to 0.0.
            proj_drop (float, optional): The probability of applying `nn.Dropout` to the projected attention output. Defaults to 0.0.
            norm_layer (Type[nn.Module], optional): A normalization function applied to Query and Key before computing attention. This function will be used if `qk_norm` is True. Defaults to `nn.LayerNorm`.

        """
        super().__init__()
        # assertion
        assert dim % num_heads == 0

        # args
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = 1 / math.sqrt(self.head_dim)
        self.attn_alpha = attn_alpha
        self.skip_alpha = skip_alpha
        self.causal = causal

        # layers
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        ## causal
        if self.causal:
            self.num_tokens = num_tokens
            self.register_buffer(
                "bias",
                torch.tril(torch.ones(self.num_tokens, self.num_tokens)).view(
                    1, 1, self.num_tokens, self.num_tokens
                ),
            )

    def forward_vanilla(
        self, x: torch.Tensor, h: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, n, d = x.size()

        # pre connection
        residual = x.clone()

        # qkv-projection
        q, k, v = self.qkv(x).split(d, dim=-1)
        q = q.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        # normalize
        q, k = self.q_norm(q), self.k_norm(k)

        # attention-logit
        q = q * self.scale
        attn = torch.einsum("Bhnd,Bhod->Bhno", q, k)

        # logit-connection
        if h is not None:
            attn = (h * self.attn_alpha) + (attn * (1 - self.attn_alpha))

        h = attn.clone()

        # softmax
        attn = torch.softmax(attn, -1)
        attn = self.attn_drop(attn)

        # qkv
        x = torch.einsum("Bhno,Bhod->Bhnd", attn, v)

        # projection
        x = x.transpose(1, 2).contiguous().view(B, n, d)

        x = ((1 - self.skip_alpha) * x) + ((self.skip_alpha) * residual)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x, h

    # TODO
    def forward_causal(
        self, x: torch.Tensor, h: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, n, d = x.size()

        # pre connection
        residual = x.clone()

        # qkv-projection
        q, k, v = self.qkv(x).split(d, dim=-1)
        q = q.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        k = k.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
        v = v.view(B, n, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

        # normalize
        q, k = self.q_norm(q), self.k_norm(k)

        # calc attention-logit
        q = q * self.scale
        attn = torch.einsum("Bhnd,Bhod->Bhno", q, k)

        # logit-connection
        if h is not None:
            attn = (h * self.attn_alpha) + (attn * (1 - self.attn_alpha))

        h = attn.clone()

        # attention masking
        attn = attn.masked_fill(self.bias[:, :, :n, :n] == 0, float("-inf"))

        # softmax
        attn = torch.softmax(attn, -1)
        attn = self.attn_drop(attn)

        # qkv
        x = torch.einsum("Bhno,Bhod->Bhnd", attn, v)

        # projection
        x = x.transpose(1, 2).contiguous().view(B, n, d)

        x = ((1 - self.skip_alpha) * x) + ((self.skip_alpha) * residual)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x, h

    def forward(
        self, x: torch.Tensor, h: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """pytorch forward function

        Args:
            x (torch.Tensor): Input tensor.
            h (torch.Tensor | None, optional): attention logits of previous layer. Defaults to None.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
        """

        if self.causal:
            x, h = self.forward_causal(x, h)
        else:
            x, h = self.forward_vanilla(x, h)

        return x, h
